package com.tomz.xatomic.xa;

import com.atomikos.jdbc.AtomikosDataSourceBean;
import com.tomz.xatomic.Constant;
import com.tomz.xatomic.DynamicDataSourceInitializer;
import org.apache.ibatis.mapping.DatabaseIdProvider;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ExecutorType;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.SqlSessionTemplate;
import org.mybatis.spring.annotation.MapperScan;
import org.mybatis.spring.boot.autoconfigure.ConfigurationCustomizer;
import org.mybatis.spring.boot.autoconfigure.MybatisAutoConfiguration;
import org.mybatis.spring.boot.autoconfigure.MybatisProperties;
import org.mybatis.spring.boot.autoconfigure.SpringBootVFS;
import org.mybatis.spring.mapper.MapperFactoryBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.context.properties.source.ConfigurationPropertyName;
import org.springframework.boot.context.properties.source.ConfigurationPropertySource;
import org.springframework.boot.context.properties.source.MapConfigurationPropertySource;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;
import org.springframework.util.*;

import javax.sql.DataSource;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Spring configuration that creates one MyBatis {@link SqlSessionFactory} and one
 * {@link SqlSessionTemplate} per target data source of an {@link AbstractRoutingDataSource},
 * and exposes them through the aggregating {@link XaDynamicSqlSessionFactory} and
 * {@link XaDynamicSqlSessionTemplate} beans.
 *
 * <h3>How to enable distributed transaction support</h3>
 * <p>
 * The application must declare {@link MapperScan} and set its {@code factoryBean} to
 * {@link XaDynamicMapperFactoryBean}, together with the {@code basePackages} to scan.
 * Without {@code MapperScan}, the default
 * {@link MybatisAutoConfiguration.MapperScannerRegistrarNotFoundConfiguration} applies and mappers are
 * scanned by {@link MybatisAutoConfiguration.AutoConfiguredMapperScannerRegistrar}; the factory bean then
 * defaults to {@link MapperFactoryBean}, which cannot map a mapper onto all data sources.
 * </p>
 * Example:
 * <pre>
 * {@link MapperScan @MapperScan}(factoryBean = XaDynamicMapperFactoryBean.class, basePackages = {"com.abc.mapper"})
 * {@link org.springframework.boot.autoconfigure.SpringBootApplication @SpringBootApplication}
 * class ApplicationDemo {
 *     ..
 * }
 * </pre>
 *
 * @see MybatisAutoConfiguration
 * @see org.springframework.boot.autoconfigure.transaction.jta.JtaAutoConfiguration
 * @author ZHUFEIFEI
 */
@Import({XaDynamicDataSourceRegistrar.class})
@ConditionalOnClass({AtomikosDataSourceBean.class, MapperFactoryBean.class})
@ConditionalOnProperty(prefix = "spring.jta", name = "enabled", matchIfMissing = true)
@EnableConfigurationProperties({MybatisProperties.class})
@org.springframework.context.annotation.Configuration
public class XaDynamicDataSourceConfiguration implements InitializingBean , EnvironmentAware, DynamicDataSourceInitializer {

    private final Logger log = LoggerFactory.getLogger(getClass());
    /** Injected through {@link #setEnvironment(Environment)}; used to bind the MyBatis configuration properties. */
    private Environment environment;
    private final MybatisProperties properties;
    /** MyBatis plugins, or {@code null} when none are registered. */
    private final Interceptor[] interceptors;
    private final ResourceLoader resourceLoader;
    /** Optional database id provider, or {@code null} when none is registered. */
    private final DatabaseIdProvider databaseIdProvider;
    /** Optional configuration customizers, or {@code null} when none are registered. */
    private final List<ConfigurationCustomizer> configurationCustomizers;

    /**
     * Collects the MyBatis settings and the optional collaborators from the application context.
     *
     * @param properties                       the bound MyBatis properties
     * @param interceptorsProvider             provider of MyBatis plugins (may supply nothing)
     * @param resourceLoader                   loader used to resolve the configured config location
     * @param databaseIdProvider               provider of the database id provider (may supply nothing)
     * @param configurationCustomizersProvider provider of configuration customizers (may supply nothing)
     */
    public XaDynamicDataSourceConfiguration(MybatisProperties properties,
                                            ObjectProvider<Interceptor[]> interceptorsProvider,
                                            ResourceLoader resourceLoader,
                                            ObjectProvider<DatabaseIdProvider> databaseIdProvider,
                                            ObjectProvider<List<ConfigurationCustomizer>> configurationCustomizersProvider) {
        this.properties = properties;
        this.interceptors = interceptorsProvider.getIfAvailable();
        this.resourceLoader = resourceLoader;
        this.databaseIdProvider = databaseIdProvider.getIfAvailable();
        this.configurationCustomizers = configurationCustomizersProvider.getIfAvailable();
    }

    /**
     * Builds the aggregating session factory: one delegate {@link SqlSessionFactory} is created for
     * every entry of the routing data source's target map and registered under the same key.
     *
     * @param dataSource the routing data source whose targets are enumerated
     * @return the {@link XaDynamicSqlSessionFactory} holding all delegates
     * @throws IllegalArgumentException if the routing data source has no targets
     * @throws Exception                if a delegate factory cannot be built
     */
    @Bean
    public SqlSessionFactory sqlSessionFactory(AbstractRoutingDataSource dataSource) throws Exception {
        XaDynamicSqlSessionFactory sqlSessionFactory = new XaDynamicSqlSessionFactory();
        Map<String, DataSource> dataSources = extractDataSource(dataSource);
        // Null-safe: the target map is null when it was never set on the routing data source.
        Assert.isTrue(!CollectionUtils.isEmpty(dataSources), "AbstractRoutingDataSource targetDataSources can not be empty!");
        Binder binder = Binder.get(this.environment);
        Map<?, ?> configProperties = binder.bind(Constant.MYBATIS_CONFIG_PREFIX, Bindable.of(Map.class))
                .orElseGet(() -> new HashMap<>(1));
        for (Map.Entry<String, DataSource> kv : dataSources.entrySet()) {
            sqlSessionFactory.add(kv.getKey(), buildSqlSessionFactory(kv.getKey(), kv.getValue(), configProperties));
            log.debug("Create SqlSessionFactory : {}", kv.getKey());
        }
        return sqlSessionFactory;
    }

    /**
     * Builds the aggregating session template: one {@link SqlSessionTemplate} per delegate factory,
     * using the configured executor type when one is set.
     *
     * @param sqlSessionFactory the session factory bean; must be an {@link XaDynamicSqlSessionFactory}
     * @return the {@link XaDynamicSqlSessionTemplate} holding all delegate templates
     * @throws IllegalArgumentException if {@code sqlSessionFactory} is not an {@link XaDynamicSqlSessionFactory}
     */
    @Bean
    public SqlSessionTemplate sqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
        // instanceof accepts subclasses and rejects everything the cast below would reject.
        Assert.isTrue(sqlSessionFactory instanceof XaDynamicSqlSessionFactory, "Not support SqlSessionFactory type!");
        XaDynamicSqlSessionFactory factory = (XaDynamicSqlSessionFactory) sqlSessionFactory;
        ExecutorType executorType = this.properties.getExecutorType();
        XaDynamicSqlSessionTemplate template = new XaDynamicSqlSessionTemplate(sqlSessionFactory);
        for (Map.Entry<String, SqlSessionFactory> kv : factory.sqlSessionFactories().entrySet()) {
            SqlSessionTemplate delegate = executorType != null
                    ? new SqlSessionTemplate(kv.getValue(), executorType)
                    : new SqlSessionTemplate(kv.getValue());
            template.add(kv.getKey(), delegate);
            log.debug("Create SqlSessionTemplate : {}", kv.getKey());
        }
        return template;
    }

    /** Verifies on startup that the configured MyBatis config file exists, when that check is enabled. */
    @Override
    public void afterPropertiesSet() {
        checkConfigFileExists();
    }

    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }

    /**
     * Fails fast when config-location checking is enabled and the configured resource is missing.
     *
     * @throws IllegalStateException if the resource does not exist
     */
    private void checkConfigFileExists() {
        if (this.properties.isCheckConfigLocation() && StringUtils.hasText(this.properties.getConfigLocation())) {
            Resource resource = this.resourceLoader.getResource(this.properties.getConfigLocation());
            Assert.state(resource.exists(), "Cannot find config location: " + resource
                    + " (please add config file or check your Mybatis configuration)");
        }
    }

    /**
     * Reads the {@code targetDataSources} field of the routing data source via reflection and
     * validates that every entry is keyed by a {@link String} and holds a {@link DataSource}.
     *
     * @param dataSource the routing data source to inspect
     * @return the target map, or {@code null} if the field has not been populated
     * @throws IllegalStateException if the field is missing or an entry has an unsupported key/value type
     */
    @SuppressWarnings("unchecked")
    private Map<String, DataSource> extractDataSource(AbstractRoutingDataSource dataSource) {
        Field field = ReflectionUtils.findField(dataSource.getClass(), "targetDataSources");
        Assert.state(field != null,
                "Field 'targetDataSources' not found on " + dataSource.getClass().getName());
        ReflectionUtils.makeAccessible(field);
        Map<?, ?> targets = (Map<?, ?>) ReflectionUtils.getField(field, dataSource);
        if (targets == null) {
            return null;
        }
        // Check element types up front so a misconfiguration is reported clearly
        // instead of surfacing later as a ClassCastException inside the caller's loop.
        for (Map.Entry<?, ?> entry : targets.entrySet()) {
            Assert.state(entry.getKey() instanceof String,
                    "targetDataSources key must be a String but was: " + entry.getKey());
            Assert.state(entry.getValue() instanceof DataSource,
                    "targetDataSources value for key '" + entry.getKey() + "' must be a javax.sql.DataSource");
        }
        // Safe: every key and value was verified above.
        return (Map<String, DataSource>) targets;
    }

    /**
     * Builds a single {@link SqlSessionFactory} bound to the given data source, applying the
     * shared MyBatis properties, plugins, type aliases/handlers and mapper locations.
     *
     * @param key              the data source key, used to derive the MyBatis environment id
     * @param dataSource       the data source the factory operates on
     * @param configProperties raw MyBatis configuration properties bound from the environment
     * @return the built session factory
     * @throws Exception if the factory bean fails to build the session factory
     */
    private SqlSessionFactory buildSqlSessionFactory(String key, DataSource dataSource, Map<?, ?> configProperties) throws Exception {
        SqlSessionFactoryBean factory = new SqlSessionFactoryBean();
        // Give each factory its own environment id so the per-data-source factories are distinguishable.
        factory.setEnvironment(SqlSessionFactoryBean.class.getSimpleName() + "#" + key);
        factory.setDataSource(dataSource);
        factory.setVfs(SpringBootVFS.class);
        if (StringUtils.hasText(this.properties.getConfigLocation())) {
            factory.setConfigLocation(this.resourceLoader.getResource(this.properties.getConfigLocation()));
        }
        applyConfiguration(factory, configProperties);
        if (this.properties.getConfigurationProperties() != null) {
            factory.setConfigurationProperties(this.properties.getConfigurationProperties());
        }
        if (!ObjectUtils.isEmpty(this.interceptors)) {
            factory.setPlugins(this.interceptors);
        }
        if (this.databaseIdProvider != null) {
            factory.setDatabaseIdProvider(this.databaseIdProvider);
        }
        if (StringUtils.hasLength(this.properties.getTypeAliasesPackage())) {
            factory.setTypeAliasesPackage(this.properties.getTypeAliasesPackage());
        }
        if (this.properties.getTypeAliasesSuperType() != null) {
            factory.setTypeAliasesSuperType(this.properties.getTypeAliasesSuperType());
        }
        if (StringUtils.hasLength(this.properties.getTypeHandlersPackage())) {
            factory.setTypeHandlersPackage(this.properties.getTypeHandlersPackage());
        }
        // Resolve once; each call re-scans the configured location patterns.
        Resource[] mapperLocations = this.properties.resolveMapperLocations();
        if (!ObjectUtils.isEmpty(mapperLocations)) {
            factory.setMapperLocations(mapperLocations);
        }
        return factory.getObject();
    }


    /**
     * Creates a dedicated MyBatis {@link Configuration} for the given factory bean and applies the
     * registered customizers to it.
     * <p>
     * A fresh instance is created per factory on purpose: if several session factories shared one
     * {@code Configuration}, the data source of the last factory would override all previous ones.
     * <p>
     * When a config location is set and no explicit configuration properties are bound, no
     * {@code Configuration} is set at all, because {@link SqlSessionFactoryBean} refuses to be
     * initialised with both a {@code configuration} and a {@code configLocation}; the XML file is
     * then the sole configuration source and customizers are not applied.
     *
     * @param factory          the factory bean to configure
     * @param configProperties raw MyBatis configuration properties; may be empty
     */
    private void applyConfiguration(SqlSessionFactoryBean factory, Map<?, ?> configProperties) {
        Configuration configuration = null;
        if (!configProperties.isEmpty()) {
            configuration = this.bind(Configuration.class, configProperties);
        } else if (!StringUtils.hasText(this.properties.getConfigLocation())) {
            configuration = new Configuration();
        }
        if (configuration == null) {
            return;
        }
        if (!CollectionUtils.isEmpty(this.configurationCustomizers)) {
            for (ConfigurationCustomizer customizer : this.configurationCustomizers) {
                customizer.customize(configuration);
            }
        }
        factory.setConfiguration(configuration);
    }

    /**
     * Binds the given property map onto a new instance of {@code clazz}.
     *
     * @param clazz      the type to instantiate and populate
     * @param properties the property names and values to bind
     * @param <T>        the target type
     * @return the bound instance
     * @throws java.util.NoSuchElementException if nothing could be bound
     */
    private <T> T bind(Class<T> clazz, Map<?, ?> properties) {
        ConfigurationPropertySource source = new MapConfigurationPropertySource(properties);
        Binder binder = new Binder(new ConfigurationPropertySource[]{source});
        return binder.bind(ConfigurationPropertyName.EMPTY, Bindable.of(clazz)).get();
    }

}
